#----------------------------------------------
# -*- coding: utf-8 -*-
# __author__: 'xiaojie'
# CreateTime:
#       2019/7/1 9:28
#
#   The storms of the age rise from our ranks,
#   yet once in the jianghu the years rush by;
#   empires and ambitions pass in talk and laughter,
#   none of it worth one good drunken spell.
#----------------------------------------------
# https://github.com/bojone/vae/blob/master/vae_keras_cluster.py

import numpy as np
import imageio  # scipy.misc.imsave was removed in SciPy 1.2, so imageio handles image saving here
from keras.layers import *
from keras.models import Model
from keras import backend as K
import os
from keras.datasets import mnist
from keras.utils import plot_model
# from keras.datasets import fashion_mnist as mnist

batch_size = 100
latent_dim = 20
epochs = 50
num_classes = 10
img_dim = 28
filters = 16
intermediate_dim = 256
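# latent_dim is the dimensionality of z; num_classes is the number of clusters (one Gaussian prior mean per cluster)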

# Load the MNIST dataset
(x_train, y_train_), (x_test, y_test_) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((-1, img_dim, img_dim, 1))
x_test = x_test.reshape((-1, img_dim, img_dim, 1))
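# x_train: (60000, 28, 28, 1), x_test: (10000, 28, 28, 1), pixel values scaled to [0, 1]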

# Build the model
x = Input(shape=(img_dim, img_dim, 1))
h = x

for i in range(2):
    filters *= 2
    h = Conv2D(filters=filters, kernel_size=3, strides=2, padding='same')(h)
    h = LeakyReLU(0.2)(h)
    h = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(h)
    h = LeakyReLU(0.2)(h)
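# After the two stride-2 blocks, h has shape (batch, 7, 7, 64) and filters == 64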

h_shape = K.int_shape(h)[1:]
h = Flatten()(h)
z_mean = Dense(latent_dim)(h)  # mean of p(z|x)
z_log_var = Dense(latent_dim)(h)  # log-variance of p(z|x)

encoder = Model(x, z_mean)  # z_mean is usually taken as the encoding of the latent variable

z = Input(shape=(latent_dim,))
h = z
h = Dense(np.prod(h_shape))(h)
h = Reshape(h_shape)(h)
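# The Dense + Reshape above project z back to the 7x7x64 feature map expected by the transposed convolutions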

for i in range(2):
    h = Conv2DTranspose(filters=filters,
                        kernel_size=3,
                        strides=1,
                        padding='same')(h)
    h = LeakyReLU(0.2)(h)
    h = Conv2DTranspose(filters=filters,
                        kernel_size=3,
                        strides=2,
                        padding='same')(h)
    h = LeakyReLU(0.2)(h)
    filters //= 2

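# The two stride-2 transposed convolutions upsample 7 -> 14 -> 28; the final layer maps back to a 28x28x1 image in [0, 1]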
x_recon = Conv2DTranspose(filters=1, kernel_size=3, activation='sigmoid', padding='same')(h)

if not os.path.exists('png'):  # plot_model needs the output directory to exist
    os.mkdir('png')

decoder = Model(z, x_recon, name='decoder')  # decoder, also used as the generator
plot_model(decoder, to_file='png/decoder.png', show_shapes=True)
generator = decoder

z = Input(shape=(latent_dim,))
y = Dense(intermediate_dim, activation='relu')(z)
y = Dense(num_classes, activation='softmax')(y)

classifier = Model(z, y)  # classifier over the latent variable
plot_model(classifier, show_shapes=True, to_file='png/classifier.png')

# Reparameterization trick: z = mean + exp(log_var / 2) * epsilon, with epsilon ~ N(0, I)
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim))
    return z_mean + K.exp(z_log_var / 2) * epsilon

# Reparameterization layer; equivalent to adding noise to the input
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
x_recon = decoder(z)
y = classifier(z)

class Gaussian(Layer):
    """A simple layer that holds the mean parameters of q(z|y), one mean vector per class.
    It outputs "z - mean" for every class, in preparation for the loss computation.
    """
    def __init__(self, num_classes, **kwargs):
        self.num_classes = num_classes
        super(Gaussian, self).__init__(**kwargs)

    def build(self, input_shape):
        latent_dim = input_shape[-1]
        self.mean = self.add_weight(name='mean',
                                    shape=(self.num_classes, latent_dim),
                                    initializer='zeros')

    def call(self, inputs):
        z = inputs  # z.shape = (batch_size, latent_dim)
        z = K.expand_dims(z, 1)  # (batch_size, 1, latent_dim)
        return z - K.expand_dims(self.mean, 0)  # (batch_size, num_classes, latent_dim)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_classes, input_shape[-1])

gaussian = Gaussian(num_classes)
z_prior_mean = gaussian(z)
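# z_prior_mean has shape (batch_size, num_classes, latent_dim): the sampled z minus each class mean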
print('z.shape',z.shape)
print('z_prior_mean.shape',z_prior_mean.shape)

# Assemble the full model
vae = Model(x, [x_recon, z_prior_mean, y])
plot_model(vae, show_shapes=True, to_file='png/vae_cluster.png')

# Everything below defines the loss
z_mean = K.expand_dims(z_mean, 1)
z_log_var = K.expand_dims(z_log_var, 1)

lamb = 2.5  # weight of the reconstruction term; a larger lamb corresponds to a smaller assumed reconstruction variance
# Pixel-wise reconstruction error (squared error averaged over the batch)
xent_loss = 0.5 * K.mean((x - x_recon)**2, 0)
# KL-type term between p(z|x) and each class prior q(z|y), with constant terms dropped
# and the sampled z used in place of the analytic expectation
kl_loss = -0.5 * (z_log_var - K.square(z_prior_mean))
# Weight the per-class KL terms by the class probabilities p(y|z)
kl_loss = K.mean(K.batch_dot(K.expand_dims(y, 1), kl_loss), 0)
# E[log p(y|z)], i.e. the negative entropy of the class posterior
cat_loss = K.mean(y * K.log(y + K.epsilon()), 0)
vae_loss = lamb * K.sum(xent_loss) + K.sum(kl_loss) + K.sum(cat_loss)

vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
# vae.summary()

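# The loss was attached via add_loss, so fit() needs no targets; validation_data likewise passes None as targets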
vae.fit(x_train,shuffle=True,epochs=epochs,batch_size=batch_size,validation_data=(x_test,None))

means = K.eval(gaussian.mean)
x_train_encoded = encoder.predict(x_train)
y_train_pred = classifier.predict(x_train_encoded).argmax(axis=1)
x_test_encoded = encoder.predict(x_test)
y_test_pred = classifier.predict(x_test_encoded).argmax(axis=1)
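# y_train_pred / y_test_pred are cluster assignments: argmax of the latent classifier at the encoder means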

def cluster_sample(path, category=0):
    """Inspect training samples that the model assigns to the same cluster."""
    n = 8
    figure = np.zeros((img_dim * n, img_dim * n))
    idxs = np.where(y_train_pred == category)[0]
    for i in range(n):
        for j in range(n):
            digit = x_train[np.random.choice(idxs)]
            digit = digit.reshape((img_dim, img_dim))
            figure[i * img_dim: (i + 1) * img_dim,
                   j * img_dim: (j + 1) * img_dim] = digit
    imageio.imwrite(path, (figure * 255).astype('uint8'))

def random_sample(path, category=0, std=1):
    """Conditionally generate random samples according to the clustering result."""
    n = 8
    figure = np.zeros((img_dim * n, img_dim * n))
    for i in range(n):
        for j in range(n):
            noise_shape = (1, latent_dim)
            z_sample = np.array(np.random.randn(*noise_shape)) * std + means[category]
            x_recon = generator.predict(z_sample)
            digit = x_recon[0].reshape((img_dim, img_dim))
            figure[i * img_dim: (i + 1) * img_dim,
                   j * img_dim: (j + 1) * img_dim] = digit
    imageio.imwrite(path, (figure * 255).astype('uint8'))

if not os.path.exists('samples'):
    os.mkdir('samples')

for i in range(10):
    cluster_sample('samples/cluster_category_%s.png' % i, i)
    random_sample('samples/category_sample_%s.png' % i, i)

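# Unsupervised clustering accuracy: map each cluster to its majority ground-truth label and count the matches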
right = 0.
for i in range(10):
    _ = np.bincount(y_train_[y_train_pred == i])
    right += _.max()

print('train acc: %s' % (right / len(y_train_)))

right = 0.
for i in range(10):
    _ = np.bincount(y_test_[y_test_pred == i])
    right += _.max()

print('test acc: %s' % (right / len(y_test_)))